# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

from datetime import timedelta

import pytest
from azure.core import MatchConditions
from azure.core.exceptions import HttpResponseError
from azure.search.documents.indexes.models import (
    AnalyzeTextOptions,
    CorsOptions,
    FreshnessScoringFunction,
    FreshnessScoringParameters,
    SearchField,
    SearchIndex,
    ScoringFunctionAggregation,
    ScoringProfile,
    SimpleField,
    SearchFieldDataType,
)
from azure.search.documents.indexes import SearchIndexClient
from devtools_testutils import AzureRecordedTestCase, recorded_by_proxy, get_credential
from search_service_preparer import SearchEnvVarPreparer, search_decorator


class TestSearchIndexClient(AzureRecordedTestCase):
    @SearchEnvVarPreparer()
    @search_decorator(schema=None, index_batch=None)
    @recorded_by_proxy
    def test_search_index_client(self, endpoint, index_name):
        """Run the index-management scenarios, in order, against one client.

        The helper steps share state on the search service (for example, the
        "list is empty" step must run before any index is created), so the
        call order below is significant.

        :param endpoint: Search service endpoint used to build the client.
        :param index_name: Ignored; it is immediately replaced by the fixed
            name ``"hotels"``.
        """
        index_name = "hotels"
        client = SearchIndexClient(
            endpoint,
            get_credential(),
            retry_backoff_factor=60,
        )

        # Checks that run before any index has been created by this test.
        self._test_get_service_statistics(client)
        self._test_list_indexes_empty(client)

        # Create the "hotels" index and read it back in several ways.
        self._test_create_index(client, index_name)
        self._test_list_indexes(client, index_name)
        self._test_get_index(client, index_name)
        self._test_get_index_statistics(client, index_name)

        # Scenarios that create their own, separately named indexes.
        self._test_delete_indexes_if_unchanged(client)
        self._test_create_or_update_index(client)
        self._test_create_or_update_indexes_if_unchanged(client)

        # Uses the "hotels" index again, then removes every listed index.
        self._test_analyze_text(client, index_name)
        self._test_delete_indexes(client)

    def _test_get_service_statistics(self, client):
        """Assert service statistics are a dict with exactly the expected keys.

        :param client: Index client whose ``get_service_statistics`` is called.
        """
        expected_keys = {"counters", "indexers_runtime", "limits"}
        statistics = client.get_service_statistics()
        assert isinstance(statistics, dict)
        assert set(statistics.keys()) == expected_keys

    def _test_list_indexes_empty(self, client):
        """Assert that listing indexes yields nothing.

        :param client: Index client whose ``list_indexes`` is called.
        """
        index_iterator = client.list_indexes()
        # An empty listing raises StopIteration on the very first next().
        with pytest.raises(StopIteration):
            next(index_iterator)

    def _test_list_indexes(self, client, index_name):
        """Assert that listing indexes yields exactly one index with the given name.

        :param client: Index client whose ``list_indexes`` is called.
        :param index_name: Name the single listed index must have.
        """
        index_iterator = client.list_indexes()
        only_index = next(index_iterator)
        assert only_index.name == index_name

        # After the first item the listing must be exhausted.
        with pytest.raises(StopIteration):
            next(index_iterator)

    def _test_get_index(self, client, index_name):
        """Assert that fetching an index by name returns an index with that name.

        :param client: Index client whose ``get_index`` is called.
        :param index_name: Name of the index to fetch.
        """
        fetched = client.get_index(index_name)
        assert fetched.name == index_name

    def _test_get_index_statistics(self, client, index_name):
        """Assert that index statistics expose the expected keys.

        :param client: Index client whose ``get_index_statistics`` is called.
        :param index_name: Name of the index to query.
        """
        statistics = client.get_index_statistics(index_name)
        reported_keys = set(statistics.keys())
        for expected_key in ("document_count", "storage_size", "vector_index_size"):
            assert expected_key in reported_keys

    def _test_create_index(self, client, index_name):
        """Create an index and check the response echoes the submitted definition.

        The index has two simple fields, one scoring profile and CORS options;
        the name, the first scoring profile's name and both CORS settings are
        compared against what was sent.

        :param client: Index client whose ``create_index`` is called.
        :param index_name: Name to give the new index.
        """
        profile = ScoringProfile(name="MyProfile")
        cors = CorsOptions(allowed_origins=["*"], max_age_in_seconds=60)
        definition = SearchIndex(
            name=index_name,
            fields=[
                SimpleField(name="hotelId", type=SearchFieldDataType.String, key=True),
                SimpleField(name="baseRate", type=SearchFieldDataType.Double),
            ],
            scoring_profiles=[profile],
            cors_options=cors,
        )

        created = client.create_index(definition)

        assert created.name == index_name
        assert created.scoring_profiles[0].name == profile.name
        assert created.cors_options.allowed_origins == cors.allowed_origins
        assert created.cors_options.max_age_in_seconds == cors.max_age_in_seconds

    def _test_create_or_update_index(self, client):
        """Exercise ``create_or_update_index`` twice on the "hotels-cou" index.

        The first call submits the index without scoring profiles; the second
        submits the same definition with one profile added. After each call
        the returned scoring profiles and CORS options are checked.

        :param client: Index client whose ``create_or_update_index`` is called.
        """
        fields = [
            SimpleField(name="hotelId", type=SearchFieldDataType.String, key=True),
            SimpleField(name="baseRate", type=SearchFieldDataType.Double),
        ]
        cors = CorsOptions(allowed_origins=["*"], max_age_in_seconds=60)

        def upsert(profiles):
            # Only the scoring profiles vary between the two submissions.
            definition = SearchIndex(
                name="hotels-cou",
                fields=fields,
                scoring_profiles=profiles,
                cors_options=cors,
            )
            return client.create_or_update_index(index=definition)

        # First submission: no scoring profiles.
        first = upsert([])
        assert len(first.scoring_profiles) == 0
        assert first.cors_options.allowed_origins == cors.allowed_origins
        assert first.cors_options.max_age_in_seconds == cors.max_age_in_seconds

        # Second submission: same index name, now with one scoring profile.
        profile = ScoringProfile(name="MyProfile")
        second = upsert([profile])
        assert second.scoring_profiles[0].name == profile.name
        assert second.cors_options.allowed_origins == cors.allowed_origins
        assert second.cors_options.max_age_in_seconds == cors.max_age_in_seconds

    def _test_create_or_update_indexes_if_unchanged(self, client):
        """Check that a conditional update with an outdated ETag is rejected.

        Creates "hotels-coa-unchanged", remembers the ETag from the create
        response, updates the index once unconditionally, then retries the
        update with the remembered ETag and ``MatchConditions.IfNotModified``
        and expects ``HttpResponseError``.

        :param client: Index client under test.
        """
        index = SearchIndex(
            name="hotels-coa-unchanged",
            fields=[
                {"name": "hotelId", "type": "Edm.String", "key": True, "searchable": False},
                {"name": "baseRate", "type": "Edm.Double"},
            ],
            scoring_profiles=[ScoringProfile(name="MyProfile")],
            cors_options=CorsOptions(allowed_origins=["*"], max_age_in_seconds=60),
        )
        original_etag = client.create_index(index).e_tag

        # Unconditional update in between, so the ETag captured above is
        # presumably no longer the service's current one.
        index.scoring_profiles = []
        client.create_or_update_index(index)

        # Re-submit with the ETag captured at creation time.
        index.e_tag = original_etag
        with pytest.raises(HttpResponseError):
            client.create_or_update_index(
                index,
                match_condition=MatchConditions.IfNotModified,
            )

    def _test_analyze_text(self, client, index_name):
        """Analyze a short text with "standard.lucene" and expect two tokens.

        :param client: Index client whose ``analyze_text`` is called.
        :param index_name: Name of the index to run the analysis against.
        """
        options = AnalyzeTextOptions(
            text="One's <two/>",
            analyzer_name="standard.lucene",
        )
        analysis = client.analyze_text(index_name, options)
        assert len(analysis.tokens) == 2

    def _test_delete_indexes_if_unchanged(self, client):
        """Check that a conditional delete with an outdated ETag is rejected.

        Creates "hotels-del-unchanged", remembers the ETag from the create
        response, updates the index once unconditionally, then attempts to
        delete it with the remembered ETag and
        ``MatchConditions.IfNotModified`` and expects ``HttpResponseError``.

        :param client: Index client under test.
        """
        index = SearchIndex(
            name="hotels-del-unchanged",
            fields=[
                {"name": "hotelId", "type": "Edm.String", "key": True, "searchable": False},
                {"name": "baseRate", "type": "Edm.Double"},
            ],
            scoring_profiles=[ScoringProfile(name="MyProfile")],
            cors_options=CorsOptions(allowed_origins=["*"], max_age_in_seconds=60),
        )
        original_etag = client.create_index(index).e_tag

        # Unconditional update in between, so the ETag captured above is
        # presumably no longer the service's current one.
        index.scoring_profiles = []
        client.create_or_update_index(index)

        # Attempt the delete with the ETag captured at creation time.
        index.e_tag = original_etag
        with pytest.raises(HttpResponseError):
            client.delete_index(index, match_condition=MatchConditions.IfNotModified)

    def _test_delete_indexes(self, client):
        """Delete every index returned by ``client.list_indexes()``.

        :param client: Index client used for both listing and deleting.
        """
        for existing_index in client.list_indexes():
            client.delete_index(existing_index)

    @SearchEnvVarPreparer()
    @recorded_by_proxy
    def test_purview_enabled_index(self, search_service_endpoint, search_service_name):
        """Create a Purview-enabled index and verify the settings round-trip.

        Both the create response and a subsequent ``get_index`` must report
        ``purview_enabled`` as True and carry a "sensitivityLabel" field whose
        ``sensitivity_label`` is True. The index is deleted afterwards on a
        best-effort basis.

        :param search_service_endpoint: Search service endpoint for the client.
        :param search_service_name: Unused.
        """
        del search_service_name  # unused
        client = SearchIndexClient(
            search_service_endpoint, get_credential(), retry_backoff_factor=60
        )

        def check_purview_settings(candidate):
            # Shared verification for the created and the fetched index.
            assert candidate.purview_enabled is True
            label_field = next(
                (f for f in candidate.fields if f.name == "sensitivityLabel"),
                None,
            )
            if label_field is None:
                raise AssertionError("Expected sensitivityLabel field to be present")
            assert label_field.sensitivity_label is True

        index_name = self.get_resource_name("purview-index")
        id_field = SearchField(
            name="id",
            type=SearchFieldDataType.String,
            key=True,
            filterable=True,
            sortable=True,
        )
        label_field_definition = SearchField(
            name="sensitivityLabel",
            type=SearchFieldDataType.String,
            filterable=True,
            sensitivity_label=True,
        )
        definition = SearchIndex(
            name=index_name,
            fields=[id_field, label_field_definition],
            purview_enabled=True,
        )

        created = client.create_index(definition)
        try:
            check_purview_settings(created)
            check_purview_settings(client.get_index(index_name))
        finally:
            # Best-effort cleanup: a failed delete must not hide the test result.
            try:
                client.delete_index(index_name)
            except HttpResponseError:
                pass

    @SearchEnvVarPreparer()
    @recorded_by_proxy
    def test_scoring_profile_product_aggregation(
        self, search_service_endpoint, search_service_name
    ):
        """Verify a scoring profile's function aggregation round-trips and updates.

        Creates an index whose single scoring profile uses the PRODUCT
        aggregation, checks it on the create response and on a fresh fetch,
        switches it to SUM via ``create_or_update_index`` and checks the
        fetched result again. The index is deleted afterwards on a
        best-effort basis.

        :param search_service_endpoint: Search service endpoint for the client.
        :param search_service_name: Unused.
        """
        del search_service_name  # unused
        client = SearchIndexClient(
            search_service_endpoint, get_credential(), retry_backoff_factor=60
        )

        def aggregation_of(candidate):
            # Aggregation mode of the first (and only) scoring profile.
            return candidate.scoring_profiles[0].function_aggregation

        index_name = self.get_resource_name("agg-product")
        freshness = FreshnessScoringFunction(
            field_name="lastUpdated",
            boost=2.5,
            parameters=FreshnessScoringParameters(
                boosting_duration=timedelta(days=7)
            ),
        )
        profile = ScoringProfile(
            name="product-score",
            function_aggregation=ScoringFunctionAggregation.PRODUCT,
            functions=[freshness],
        )
        definition = SearchIndex(
            name=index_name,
            fields=[
                SimpleField(name="hotelId", type=SearchFieldDataType.String, key=True),
                SimpleField(
                    name="lastUpdated",
                    type=SearchFieldDataType.DateTimeOffset,
                    filterable=True,
                ),
            ],
            scoring_profiles=[profile],
        )

        created = client.create_index(definition)
        try:
            assert aggregation_of(created) == ScoringFunctionAggregation.PRODUCT

            fetched = client.get_index(index_name)
            assert aggregation_of(fetched) == ScoringFunctionAggregation.PRODUCT

            # Flip the aggregation on the fetched model and push it back.
            fetched.scoring_profiles[0].function_aggregation = (
                ScoringFunctionAggregation.SUM
            )
            client.create_or_update_index(index=fetched)

            updated = client.get_index(index_name)
            assert aggregation_of(updated) == ScoringFunctionAggregation.SUM
        finally:
            # Best-effort cleanup: a failed delete must not hide the test result.
            try:
                client.delete_index(index_name)
            except HttpResponseError:
                pass
